package com.github.hgkmail.hello.service;

import com.github.hgkmail.hello.dao.employees.EmployeesDao;
import com.github.hgkmail.hello.dao.employees.SalaryDao;
import com.github.hgkmail.hello.dao.employees.SalaryStatsDao;
import com.github.hgkmail.hello.entity.employees.Employee;
import com.github.hgkmail.hello.entity.employees.Salary;
import com.github.hgkmail.hello.entity.employees.SalaryStatsRecord;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.commons.collections.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.math.BigDecimal;
import java.sql.Date;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

@Service
/**
 * Builds one {@link SalaryStatsRecord} per employee from that employee's salary rows, using one of
 * three alternative strategies: single-threaded ({@link #simpleStats()}), a thread pool with one
 * read-and-write task per page ({@link #multiThreadStats()}), or a producer/consumer pipeline
 * ({@link #producerConsumerStats()}).
 */
public class SalaryStatsService {
    private static final Logger logger = LoggerFactory.getLogger(SalaryStatsService.class);

    /** emp_no carried by the poison-pill record (assumed never to be a real employee number). */
    private static final int POISON_PILL_EMP_NO = -100;

    /**
     * Sentinel that tells a {@link WriteRecordToDB} writer to flush its batch and stop.
     * Writers recognise it by identity, so a real record can never be mistaken for it.
     */
    private static final SalaryStatsRecord POISON_PILL = new SalaryStatsRecord(POISON_PILL_EMP_NO);

    @Autowired
    EmployeesDao employeesDao;

    @Autowired
    SalaryDao salaryDao;

    @Autowired
    SalaryStatsDao salaryStatsDao;

    /**
     * The simplest way to compute the statistics: single-threaded, one record saved at a time.
     *
     * <p>Walks the employees page by page (100 per page) and saves one record for every employee
     * that has salary rows. Unlike the other strategies, this method does not call
     * {@link #deleteOldRecords()} first.
     *
     * @return the number of statistics records generated (sum of {@code saveRecord} results)
     */
    public int simpleStats() {
        int startEmpNo = 0;
        int pageSize = 100;
        int recordCount = 0;
        while (true) {
            List<Employee> employeeList = employeesDao.getByPage(startEmpNo, pageSize);
            if (CollectionUtils.isEmpty(employeeList)) {
                break;
            }
            // The highest emp_no of this page is the cursor for the next page.
            startEmpNo = maxEmpNo(employeeList);
            if (startEmpNo < 0) {
                logger.info("startEmpNo<0 {}", startEmpNo);
                break;
            }

            for (Employee employee : employeeList) {
                SalaryStatsRecord record = getEmployeeSalaryStatsRecord(employee.getEmp_no());
                if (Objects.nonNull(record)) {
                    recordCount += salaryStatsDao.saveRecord(record);
                }
            }
        }

        return recordCount;
    }

    /**
     * Deletes old records, then hands each page of 100 employees to a 10-thread pool. Each task
     * computes the records for its page and saves them with a single batch insert.
     *
     * <p>Waiting for the pool is bounded by {@link #awaitTerminationAfterShutdown}; tasks still
     * pending after that are dropped (and logged), so the result may then be partial.
     *
     * @return the total number of rows reported by {@code batchSaveRecord}
     */
    public int multiThreadStats() {
        // Per-invocation counter: concurrent calls on this singleton must not share it.
        AtomicInteger recordCount = new AtomicInteger(0);

        int pageSize = 100;
        int employeeCount = employeesDao.countAll();
        // Upper bound on the number of page tasks; always >= 1, so the queue capacity is valid.
        int taskNum = employeeCount / pageSize + 1;
        deleteOldRecords();

        ExecutorService pool = new ThreadPoolExecutor(10, 10, 5, TimeUnit.SECONDS,
                new ArrayBlockingQueue<>(taskNum),
                new ThreadFactoryBuilder().setNameFormat("multiThreadStats-producerPool-%d").build());

        try {
            int startEmpNo = 0;
            while (true) {
                List<Employee> employeeList = employeesDao.getByPage(startEmpNo, pageSize);
                if (CollectionUtils.isEmpty(employeeList)) {
                    break;
                }
                startEmpNo = maxEmpNo(employeeList);
                if (startEmpNo < 0) {
                    logger.info("empNo<0 {}", startEmpNo);
                    break;
                }
                pool.submit(new ReadAndWriteTask(employeeList, recordCount));
            }
        } finally {
            // Always shut the pool down, even if paging throws; otherwise its core threads leak.
            awaitTerminationAfterShutdown(pool);
        }

        return recordCount.get();
    }

    /**
     * Producer/consumer variant, written the same way as the Go version.
     *
     * <p>Deletes old records, then 10 producer threads compute one record per employee and put it
     * on a shared queue, while 2 writer threads take records and save them in batches of 20. After
     * all producers finish, one poison pill per writer is enqueued and this method blocks until
     * every writer has flushed and stopped.
     *
     * @return the total number of rows reported by {@code batchSaveRecord}
     */
    public int producerConsumerStats() {
        // Per-invocation counter, updated by the writer threads (AtomicInteger is thread-safe).
        AtomicInteger recordCount = new AtomicInteger(0);

        int employeeCount = employeesDao.countAll();
        deleteOldRecords();
        int writerCount = 2;

        // Queue capacities must be positive; an empty employee table would otherwise give 0,
        // which LinkedBlockingQueue rejects with IllegalArgumentException.
        ExecutorService producerPool = new ThreadPoolExecutor(10, 10, 5, TimeUnit.SECONDS,
                new LinkedBlockingQueue<>(Math.max(1, employeeCount)),
                new ThreadFactoryBuilder().setNameFormat("producerConsumerStats-producerPool-%d").build());
        ExecutorService consumerPool = new ThreadPoolExecutor(writerCount, writerCount, 5, TimeUnit.SECONDS,
                new LinkedBlockingQueue<>(writerCount),
                new ThreadFactoryBuilder().setNameFormat("producerConsumerStats-consumerPool-%d").build());
        // Unbounded: holds at most one record per employee, and putting the poison pills can never
        // block behind a full queue (which would hang if the writers had died).
        BlockingQueue<SalaryStatsRecord> queueToSave = new LinkedBlockingDeque<>();

        List<Future<?>> writers = new ArrayList<>(writerCount);
        try {
            // Consumers: write the records to the database.
            for (int i = 0; i < writerCount; i++) {
                writers.add(consumerPool.submit(new WriteRecordToDB(queueToSave, recordCount)));
            }

            // Producers: walk the employees and compute one record each.
            int startEmpNo = 0;
            int pageSize = 100;
            while (true) {
                List<Employee> employeeList = employeesDao.getByPage(startEmpNo, pageSize);
                if (CollectionUtils.isEmpty(employeeList)) {
                    break;
                }
                startEmpNo = maxEmpNo(employeeList);
                if (startEmpNo < 0) {
                    logger.info("empNo<0 {}", startEmpNo);
                    break;
                }

                for (Employee employee : employeeList) {
                    producerPool.submit(new GetEmployeeSalaryStatsRecord(employee.getEmp_no(), queueToSave));
                }
            }
            awaitTerminationAfterShutdown(producerPool);

            // One poison pill per writer, queued behind every real record.
            for (int i = 0; i < writerCount; i++) {
                try {
                    queueToSave.put(POISON_PILL);
                } catch (InterruptedException e) {
                    logger.warn("毒丸放入队列异常", e);
                    Thread.currentThread().interrupt();
                    // Retrying would throw again immediately; the pool shutdown below
                    // interrupts the writers instead.
                    break;
                }
            }

            // Block until the writers have drained the queue, instead of polling isEmpty().
            awaitWriters(writers);
            awaitTerminationAfterShutdown(consumerPool);
        } finally {
            // Safety net for the exceptional path; a no-op for pools that already terminated.
            producerPool.shutdownNow();
            consumerPool.shutdownNow();
        }

        return recordCount.get();
    }

    /**
     * Blocks until every writer task has completed, logging the failure any of them ended with.
     * Returns early, with the interrupt flag restored, if the calling thread is interrupted.
     */
    private void awaitWriters(List<Future<?>> writers) {
        for (Future<?> writer : writers) {
            try {
                writer.get();
            } catch (ExecutionException e) {
                logger.error("写入数据库任务异常", e.getCause());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    /** Producer task: computes the statistics record for one employee and enqueues it. */
    private class GetEmployeeSalaryStatsRecord implements Runnable {

        private final int empNo;
        private final BlockingQueue<SalaryStatsRecord> queueToSave;

        public GetEmployeeSalaryStatsRecord(int empNo, BlockingQueue<SalaryStatsRecord> queueToSave) {
            this.empNo = empNo;
            this.queueToSave = queueToSave;
        }

        @Override
        public void run() {
            try {
                SalaryStatsRecord record = getEmployeeSalaryStatsRecord(empNo);
                // Employees without salary rows produce no record.
                if (Objects.nonNull(record)) {
                    queueToSave.put(record);
                }
            } catch (InterruptedException e) {
                logger.warn("统计结果放入队列异常 empNo-{}", empNo, e);
                Thread.currentThread().interrupt();
            } catch (RuntimeException e) {
                // The task runs via submit(), so without this the failure would stay hidden in an unread Future.
                logger.error("统计员工薪资异常 empNo-{}", empNo, e);
            }
        }
    }

    /**
     * Consumer task: takes records from the queue and saves them in batches of
     * {@value #NUM_TO_SAVE}, until it takes the poison pill or is interrupted. The final
     * partial batch is saved in both cases.
     */
    private class WriteRecordToDB implements Runnable {

        private static final int NUM_TO_SAVE = 20;
        private final BlockingQueue<SalaryStatsRecord> queueToSave;
        private final AtomicInteger recordCount;

        public WriteRecordToDB(BlockingQueue<SalaryStatsRecord> queueToSave, AtomicInteger recordCount) {
            this.queueToSave = queueToSave;
            this.recordCount = recordCount;
        }

        @Override
        public void run() {
            List<SalaryStatsRecord> records = new ArrayList<>(NUM_TO_SAVE);
            try {
                while (true) {
                    SalaryStatsRecord record = queueToSave.take();
                    if (record == POISON_PILL) {
                        break;
                    }
                    records.add(record);
                    if (records.size() >= NUM_TO_SAVE) {
                        flush(records);
                    }
                }
            } catch (InterruptedException e) {
                // Stop consuming: calling take() again with the flag set would throw immediately
                // and spin forever.
                logger.warn("从队列取记录异常", e);
                Thread.currentThread().interrupt();
            }

            flush(records);
        }

        /** Saves and clears the pending batch (if any), adding the affected-row count to the total. */
        private void flush(List<SalaryStatsRecord> records) {
            if (CollectionUtils.isNotEmpty(records)) {
                int affect = salaryStatsDao.batchSaveRecord(records);
                records.clear();
                recordCount.addAndGet(affect);
            }
        }
    }

    /** Computes and batch-saves the records for one page of employees. */
    private class ReadAndWriteTask implements Runnable {
        private final List<Employee> employeeList;
        private final AtomicInteger recordCount;

        public ReadAndWriteTask(List<Employee> employeeList, AtomicInteger recordCount) {
            this.employeeList = employeeList;
            this.recordCount = recordCount;
        }

        @Override
        public void run() {
            try {
                List<SalaryStatsRecord> recordList = new ArrayList<>(employeeList.size());
                for (Employee employee : employeeList) {
                    SalaryStatsRecord salaryStatsRecord = getEmployeeSalaryStatsRecord(employee.getEmp_no());
                    if (Objects.nonNull(salaryStatsRecord)) {
                        recordList.add(salaryStatsRecord);
                    }
                }
                // Skip empty batches, as WriteRecordToDB does.
                // NOTE(review): a multi-row INSERT with no rows is presumably invalid SQL — confirm against the mapper.
                if (CollectionUtils.isNotEmpty(recordList)) {
                    recordCount.addAndGet(salaryStatsDao.batchSaveRecord(recordList));
                }
            } catch (RuntimeException e) {
                // The task runs via submit(), so without this the failure would stay hidden in an unread Future.
                logger.error("批量统计写入异常", e);
            }
        }
    }

    /**
     * Builds the statistics record for one employee from their salary rows.
     *
     * @param empNo employee number
     * @return the record, or {@code null} if the employee has no salary rows. It holds the earliest
     *     from_date, the latest to_date, the salary total, the number of salary rows
     *     ({@code year_num}) and the average salary per row.
     */
    private SalaryStatsRecord getEmployeeSalaryStatsRecord(int empNo) {
        List<Salary> salaries = salaryDao.getSalaryByEmpNo(empNo);
        if (CollectionUtils.isEmpty(salaries)) {
            return null;
        }

        SalaryStatsRecord record = new SalaryStatsRecord(empNo);

        Date startDate = salaries.stream().map(Salary::getFrom_date)
                .min(Comparator.comparingLong(Date::getTime)).orElse(new Date(0));
        record.setStart_date(startDate);

        Date endDate = salaries.stream().map(Salary::getTo_date)
                .max(Comparator.comparingLong(Date::getTime)).orElse(new Date(0));
        record.setEnd_date(endDate);

        // Accumulate in long so that a large total cannot silently overflow an int.
        long totalSalary = salaries.stream().mapToLong(Salary::getSalary).sum();
        record.setTotal_salary(BigDecimal.valueOf(totalSalary));

        int yearNum = salaries.size();
        record.setYear_num(yearNum);

        BigDecimal avgSalary = BigDecimal.valueOf(totalSalary / (double) yearNum);
        record.setAverage_salary(avgSalary);

        return record;
    }

    /** Returns the largest emp_no in the page, or -1 for an empty page; used as the paging cursor. */
    private static int maxEmpNo(List<Employee> employees) {
        return employees.stream().mapToInt(Employee::getEmp_no).max().orElse(-1);
    }

    /**
     * Shuts the pool down and waits up to 30 seconds for the submitted tasks to finish.
     *
     * <p>On timeout, the pool is force-stopped and the number of tasks that never ran is logged.
     * On interruption, the pool is force-stopped and the interrupt flag is restored.
     */
    private void awaitTerminationAfterShutdown(ExecutorService threadPool) {
        // Let already-submitted tasks run, but accept no new ones.
        threadPool.shutdown();
        try {
            if (!threadPool.awaitTermination(30, TimeUnit.SECONDS)) {
                List<Runnable> dropped = threadPool.shutdownNow();
                logger.warn("线程池30秒内未结束，已强制关闭，丢弃未执行任务数 {}", dropped.size());
            }
        } catch (InterruptedException ex) {
            logger.info("等待线程池关闭异常", ex);
            // Standard ExecutorService pattern: cancel the remaining work, then restore the interrupt flag.
            threadPool.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Deletes previously generated statistics records; this must run before regenerating them.
     *
     * <p>Walks the employees in pages of 1000 and calls {@code salaryStatsDao.deleteBefore} with the
     * highest emp_no of each page. NOTE(review): presumably {@code deleteBefore(x)} removes records
     * with emp_no up to {@code x} — confirm the boundary (inclusive or exclusive) against the DAO.
     */
    public void deleteOldRecords() {
        int startEmpNo = 0;
        int pageSize = 1000;
        while (true) {
            List<Employee> employeeList = employeesDao.getByPage(startEmpNo, pageSize);
            if (CollectionUtils.isEmpty(employeeList)) {
                break;
            }
            startEmpNo = maxEmpNo(employeeList);
            if (startEmpNo < 0) {
                logger.info("empNo<0 {}", startEmpNo);
                break;
            }
            salaryStatsDao.deleteBefore(startEmpNo);
        }
    }
}
